# Copyright (c) 2023 Copyright holder of the paper "Revisiting Image Classifier Training for Improved Certified Robust Defense against Adversarial Patches" submitted to TMLR for review

# All rights reserved.

import argparse
import numpy as np
import math
import random
import time
import pandas as pd

import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import timm

from build import generate_masks
from utils import *

# python certify.py -b 64 -c adv_maskset3x3 --pretrained-model resnetv2 --dataset imagenet -k 3 --certi-pixels-percent 3.0

# Backbone identifiers accepted by --pretrained-model; main() maps each one to
# a concrete timm architecture.
pretrained_model_options = ["resnetv2", "vit_base", "convnext"]
# Dataset identifiers accepted by --dataset.
datasets = ["imagenet", "cifar10", "cifar100", "imagenette", 'svhn']

# Command-line interface. Help strings use argparse's ``%(default)s``
# placeholder so the advertised default can never drift from the real one.
parser = argparse.ArgumentParser(description='PatchCleanser Certification')
parser.add_argument('-j', '--workers', default=6, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('-b', '--batch-size', default=64, type=int,
                    metavar='N')
parser.add_argument('-p', '--print-freq', default=1000, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--seed', default=2022, type=int,
                    help='seed for initializing training. ')
parser.add_argument('-c', '--checkpoint-name', type=str, help='checkpoint name.')
parser.add_argument('--pretrained-model', type=str, default="resnetv2",
                    choices=pretrained_model_options)
parser.add_argument('--dataset', metavar='dataset', default='imagenet',
                    choices=datasets)
parser.add_argument('-k', '--num-mask-locations', default=3, type=int, metavar='N')  # e.g. 3x3 = 9 total mask patches
parser.add_argument('--certi-pixels-percent', default=3.0, type=float, metavar='N')

# Parsed at import time because certification() reads ``args.print_freq``
# from this module-level namespace.
args = parser.parse_args()

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def validate_clean(val_loader, classifier):
    """Evaluate top-1 accuracy of ``classifier`` on unmodified validation images.

    Args:
        val_loader: iterable yielding batches whose first two items are the
            images and their labels.
        classifier: model to evaluate; it is switched to eval mode.

    Returns:
        The running average of the top-1 score reported by ``accuracy``,
        rounded to two decimals.
    """
    classifier.eval()
    top1 = AverageMeter('Clean Acc@1', ':6.2f')
    eval_start_time = time.time()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            # Move the batch to the module-level ``device`` (same as
            # certification()) so the script also runs on CPU-only hosts.
            images, target = data[0].to(device), data[1].to(device)
            output = classifier(normalize(images))
            # Only the top-1 score is tracked; the top-5 value is discarded.
            acc1, _ = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], target.size(0))
        print('Clean Acc@1 {top1.avg:.3f} Time {Time:.3f} secs'.format(top1=top1, Time=time.time() - eval_start_time))
    return round(top1.avg.item(), 2)


def certification(val_loader, classifier, masks_set, img_size):
    """Measure how often the classifier is correct under every pair of masks.

    For each validation image, every unordered pair of masks taken from
    ``masks_set`` (a mask paired with itself included) is applied, the covered
    pixels are replaced with 0.5, and the image counts as certified only when
    the prediction equals the label for all pairs.

    Args:
        val_loader: iterable yielding batches whose first two items are the
            images and their labels.
        classifier: model to evaluate; it is switched to eval mode.
        masks_set: tensor indexed as ``masks_set[:, i]`` for the i-th mask.
            Assumed to be binary with shape (1, N, img_size, img_size) and 1
            marking occluded pixels -- TODO confirm against ``generate_masks``.
        img_size: height/width used to allocate the combined masks.

    Returns:
        The running average of the per-batch certified percentage, rounded to
        two decimals.
    """
    classifier.eval()

    top1 = AverageMeter('Certification Acc@1', ':6.2f')
    eval_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(len(val_loader), [top1, eval_time])

    # Build the set of all two-mask combinations. Working on the inverted
    # masks lets the product act as an intersection of the complements, so
    # ``1 - product`` is the union of the two original masks.
    masks_set = 1 - masks_set
    mask_set_size = masks_set.shape[1]
    # Number of unordered pairs with repetition: 1 + 2 + ... + mask_set_size.
    num_combinations = mask_set_size * (mask_set_size + 1) // 2
    masks_set_combinations = torch.zeros(1, num_combinations, img_size, img_size)
    count = 0
    for mask1_idx in range(mask_set_size):
        for mask2_idx in range(mask1_idx, mask_set_size):
            masks_set_combinations[:, count] = 1 - masks_set[:, mask1_idx] * masks_set[:, mask2_idx]
            count += 1
    # Extra singleton dimension so each combined mask broadcasts over the
    # channel dimension of the image batch.
    masks_set_combinations_view = masks_set_combinations.unsqueeze(dim=2).to(device)

    eval_start_time = time.time()
    # The 'Time' meter tracks the duration of each batch (data loading
    # included), so its running average is a meaningful per-batch figure.
    batch_start_time = eval_start_time
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            clean_images, target = data[0].to(device), data[1].to(device)
            batch_size = clean_images.shape[0]

            # Kept on ``device`` so recording each prediction does not force
            # a device-to-host copy for every mask combination.
            prediction_set = torch.zeros(batch_size, num_combinations, device=device)

            for mask_idx in range(num_combinations):
                combined_mask = masks_set_combinations_view[:, mask_idx]
                # Replace the covered pixels with the constant 0.5.
                masked_images_set = (1 - combined_mask) * clean_images + combined_mask * 0.5
                output = classifier(normalize(masked_images_set))
                pred = torch.argmax(output, dim=1)
                prediction_set[:, mask_idx] = pred.eq(target).float()

            # The product over combinations is 1 only if every masked
            # prediction was correct.
            correct = (torch.prod(prediction_set, dim=1)).sum()
            acc = correct.mul_(100.0 / batch_size)
            top1.update(acc, batch_size)

            now = time.time()
            eval_time.update(now - batch_start_time)
            batch_start_time = now
            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Certification Acc@1 {top1.avg:.3f} Time {Time:.3f} secs'.format(
            top1=top1, Time=time.time() - eval_start_time))

    return round(top1.avg.item(), 2)


def get_patch_size(img_size, certi_pixels_percent):
    """Return the side length of a square patch covering a share of the image.

    Args:
        img_size: side length of the square image, in pixels.
        certi_pixels_percent: percentage of the image's pixels the patch
            should cover.

    Returns:
        The square root of the pixel budget, rounded up to an integer.
    """
    pixel_budget = (certi_pixels_percent / 100) * (img_size ** 2)
    side_length = np.sqrt(pixel_budget)
    # Round up so the patch covers at least the requested share.
    return math.ceil(side_length)


def main():
    """Certify a saved checkpoint and store clean/certified accuracy as CSV.

    Steps: parse the command line, seed the RNGs, locate the checkpoint,
    build the mask set for the requested patch size, create the timm model,
    load the checkpoint weights, then either print previously saved results
    or run the clean and certified evaluations and write them to a CSV file.
    Returns early (without raising) when the checkpoint is missing or was
    saved before epoch 10.
    """
    # Imported locally so this function does not depend on ``os`` being
    # re-exported through ``from utils import *``.
    import os

    cudnn.benchmark = True
    args = parser.parse_args()
    if args.seed is not None:
        # set the seed
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    checkpoint_path = set_checkpoint_path(args.checkpoint_name, args.dataset, args.pretrained_model)
    # Fall back to the same path without the ".tar" suffix before giving up.
    if not os.path.exists(checkpoint_path):
        checkpoint_path = checkpoint_path.removesuffix(".tar")
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint {checkpoint_path} does not exist. Please check the checkpoint and its path.")
        return

    print("Dataset:", args.dataset)
    print("Batchsize:", args.batch_size)
    print("Random seed:", args.seed)
    print("Mask set size:", args.num_mask_locations ** 2)
    print("Pretrained model:", args.pretrained_model)
    print("Checkpoint path:", checkpoint_path)
    print(f"Certification pixels: {args.certi_pixels_percent}%")

    img_size = 224
    adv_patch_size = get_patch_size(img_size, args.certi_pixels_percent)  # 39x39 constitutes 3% pixels of 224x224 image
    # Stride and mask size are derived from the patch size and the number of
    # mask locations per axis.
    stride = math.ceil((img_size - adv_patch_size + 1) / args.num_mask_locations)
    mask_size = adv_patch_size + stride - 1
    masks_set = generate_masks(mask_size, stride, (args.num_mask_locations ** 2), img_size)
    print(f"Certification patch size: {adv_patch_size}x{adv_patch_size}")
    print("\n")

    # Only the validation loader is needed for certification.
    _, val_loader = get_dataloaders(args.dataset, train_batch_size=args.batch_size, val_batch_size=args.batch_size)
    num_classes = get_num_classes(args.dataset)

    # argparse ``choices`` guarantees the key is present.
    timm_model_names = {
        "resnetv2": 'resnetv2_50x1_bit_distilled',
        "vit_base": 'vit_base_patch16_224',
        "convnext": 'convnext_tiny_in22ft1k',
    }
    classifier = timm.create_model(timm_model_names[args.pretrained_model], pretrained=True)

    # Replace the 1000-way head when the dataset has a different class count.
    if num_classes != 1000:
        classifier.reset_classifier(num_classes=num_classes)

    classifier.to(device)

    # ``map_location`` lets a checkpoint saved on a GPU load on whatever
    # ``device`` resolved to, including the CPU fallback.
    saved_state = torch.load(checkpoint_path, map_location=device)
    classifier.load_state_dict(saved_state["state_dict"])
    saved_epoch = saved_state["epoch"]
    print("Saved Epoch:", saved_epoch)

    if saved_epoch + 1 < 10:
        print("Model has not trained for 10 epochs. Try after the model training finishes.")
        return

    result_dir = set_result_path(args.checkpoint_name, args.dataset, args.pretrained_model)
    result_path = os.path.join(result_dir, f"k_{args.num_mask_locations}_cert_{args.certi_pixels_percent}.csv")
    if os.path.exists(result_path):
        print("Certification results for this setting already exists. Retrieving the existing results...")
        df = pd.read_csv(result_path)
        print("Clean acc: ", df["clean_acc"].item())
        print("Certification acc: ", df["certification_acc"].item())
        return

    clean_acc = validate_clean(val_loader, classifier)
    certification_acc = certification(val_loader, classifier, masks_set, img_size)

    # saving the results
    results = [
        [clean_acc, certification_acc, args.num_mask_locations, args.certi_pixels_percent],
    ]
    columns = ["clean_acc", "certification_acc", "num_mask_locations", "certi_pixels_percent"]
    df = pd.DataFrame(results, columns=columns)
    df["clean_acc"] = df["clean_acc"].map(float)
    df["certification_acc"] = df["certification_acc"].map(float)
    df["num_mask_locations"] = df["num_mask_locations"].map(int)
    df["certi_pixels_percent"] = df["certi_pixels_percent"].map(float)

    # Make sure the destination directory exists so a finished evaluation is
    # not lost to a missing folder at write time.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    df.to_csv(result_path, index=False)
    print("Result saved to ", result_path)


# Run the certification pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
